use std::ffi::c_void;
use std::mem;
use std::ptr;
use windows::Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryW};
use windows::core::HSTRING;

/// Policy selector passed to `RtlQueryImageMitigationPolicy` / `RtlSetImageMitigationPolicy`.
///
/// Discriminants are written out explicitly because the value crosses the FFI boundary as
/// `policy as u32`; they are the same sequential 0..=16 values implicit numbering yields.
// NOTE(review): `dead_code` is presumably allowed so the whole native enum can be mirrored
// even though not every variant is used — confirm.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum ImageMitigationPolicy {
    ImageDepPolicy = 0,
    ImageAslrPolicy = 1,
    ImageDynamicCodePolicy = 2,
    ImageStrictHandleCheckPolicy = 3,
    ImageSystemCallDisablePolicy = 4,
    ImageMitigationOptionsMask = 5,
    ImageExtensionPointDisablePolicy = 6,
    ImageControlFlowGuardPolicy = 7,
    ImageSignaturePolicy = 8,
    ImageFontDisablePolicy = 9,
    ImageImageLoadPolicy = 10,
    ImagePayloadRestrictionPolicy = 11,
    ImageChildProcessPolicy = 12,
    ImageSehopPolicy = 13,
    ImageHeapPolicy = 14,
    ImageUserShadowStackPolicy = 15,
    MaxImageMitigationPolicy = 16,
}

/// Option state values for a mitigation policy.
///
/// NOTE(review): public headers (e.g. phnt `ntrtl.h`) are believed to number this enum
/// sequentially (`Force = 3`, `Option = 4`), with 0x4/0x8 belonging to separate
/// `RTL_IMAGE_MITIGATION_FLAG_*` bits — confirm before relying on `Force`/`Option` here.
// NOTE(review): `dead_code` is presumably allowed because not every state is used — confirm.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum RtlImageMitigationOptionState {
    RtlMitigationOptionStateNotConfigured = 0x0,
    RtlMitigationOptionStateOn = 0x1,
    RtlMitigationOptionStateOff = 0x2,
    RtlMitigationOptionStateForce = 0x4,
    RtlMitigationOptionStateOption = 0x8,
}

/// A single mitigation setting as exchanged with ntdll: one raw 64-bit word.
///
/// NOTE(review): the native `RTL_IMAGE_MITIGATION_POLICY` is believed to be a bitfield
/// union (low bits holding an option state, plus inherit/audit bits); this wrapper only
/// exposes the raw word — confirm the bit layout before interpreting `policy_state`.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationPolicy {
    /// Raw 64-bit policy word.
    pub policy_state: u64,
}

/// Buffer for [`ImageMitigationPolicy::ImageAslrPolicy`] (pairing inferred from names —
/// confirm against the native `RTL_IMAGE_MITIGATION_ASLR_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationAslrPolicy {
    pub force_relocate_images: RtlImageMitigationPolicy,
    pub bottom_up_randomization: RtlImageMitigationPolicy,
    pub high_entropy_randomization: RtlImageMitigationPolicy,
}

/// Buffer for [`ImageMitigationPolicy::ImageDepPolicy`] (pairing inferred from names —
/// confirm against the native `RTL_IMAGE_MITIGATION_DEP_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationDepPolicy {
    pub dep: RtlImageMitigationPolicy,
}

/// Buffer for [`ImageMitigationPolicy::ImageStrictHandleCheckPolicy`] (pairing inferred
/// from names — confirm against the native `RTL_IMAGE_MITIGATION_STRICT_HANDLE_CHECK_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationStrictHandleCheckPolicy {
    pub strict_handle_checks: RtlImageMitigationPolicy,
}

/// Buffer for [`ImageMitigationPolicy::ImageSystemCallDisablePolicy`] (pairing inferred
/// from names — confirm against the native `RTL_IMAGE_MITIGATION_SYSTEM_CALL_DISABLE_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationSystemCallDisablePolicy {
    pub block_win32k_system_calls: RtlImageMitigationPolicy,
    pub block_fsctl_system_calls: RtlImageMitigationPolicy,
}

/// Buffer for [`ImageMitigationPolicy::ImageExtensionPointDisablePolicy`] (pairing inferred
/// from names — confirm against the native
/// `RTL_IMAGE_MITIGATION_EXTENSION_POINT_DISABLE_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationExtensionPointDisablePolicy {
    pub disable_extension_points: RtlImageMitigationPolicy,
}

/// Buffer for [`ImageMitigationPolicy::ImageDynamicCodePolicy`] (pairing inferred from
/// names — confirm against the native `RTL_IMAGE_MITIGATION_DYNAMIC_CODE_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationDynamicCodePolicy {
    pub block_dynamic_code: RtlImageMitigationPolicy,
}

/// Buffer for [`ImageMitigationPolicy::ImageControlFlowGuardPolicy`] (pairing inferred from
/// names — confirm against the native `RTL_IMAGE_MITIGATION_CONTROL_FLOW_GUARD_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationControlFlowGuardPolicy {
    pub control_flow_guard: RtlImageMitigationPolicy,
    pub strict_control_flow_guard: RtlImageMitigationPolicy,
}

/// Buffer presumably for [`ImageMitigationPolicy::ImageSignaturePolicy`] (pairing inferred
/// from names — confirm against the native `RTL_IMAGE_MITIGATION_BINARY_SIGNATURE_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationBinarySignaturePolicy {
    pub block_non_microsoft_signed_binaries: RtlImageMitigationPolicy,
    pub enforce_signing_on_module_dependencies: RtlImageMitigationPolicy,
}

/// Buffer for [`ImageMitigationPolicy::ImageFontDisablePolicy`] (pairing inferred from
/// names — confirm against the native `RTL_IMAGE_MITIGATION_FONT_DISABLE_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationFontDisablePolicy {
    pub disable_non_system_fonts: RtlImageMitigationPolicy,
}

/// Buffer for [`ImageMitigationPolicy::ImageImageLoadPolicy`] (pairing inferred from names —
/// confirm against the native `RTL_IMAGE_MITIGATION_IMAGE_LOAD_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationImageLoadPolicy {
    pub block_remote_image_loads: RtlImageMitigationPolicy,
    pub block_low_label_image_loads: RtlImageMitigationPolicy,
    pub prefer_system32: RtlImageMitigationPolicy,
}

/// Buffer for [`ImageMitigationPolicy::ImagePayloadRestrictionPolicy`] (pairing inferred
/// from names — confirm against the native `RTL_IMAGE_MITIGATION_PAYLOAD_RESTRICTION_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationPayloadRestrictionPolicy {
    pub enable_export_address_filter: RtlImageMitigationPolicy,
    pub enable_export_address_filter_plus: RtlImageMitigationPolicy,
    pub enable_import_address_filter: RtlImageMitigationPolicy,
    pub enable_rop_stack_pivot: RtlImageMitigationPolicy,
    pub enable_rop_caller_check: RtlImageMitigationPolicy,
    pub enable_rop_sim_exec: RtlImageMitigationPolicy,
    /// 512 UTF-16 code units; presumably a NUL-terminated module list used by the
    /// export-address-filter-plus setting — confirm against the native header.
    pub module_list: [u16; 512],
}

/// Buffer for [`ImageMitigationPolicy::ImageSehopPolicy`] (pairing inferred from names —
/// confirm against the native `RTL_IMAGE_MITIGATION_SEHOP_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationSehopPolicy {
    pub sehop: RtlImageMitigationPolicy,
}

/// Buffer for [`ImageMitigationPolicy::ImageHeapPolicy`] (pairing inferred from names —
/// confirm against the native `RTL_IMAGE_MITIGATION_HEAP_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationHeapPolicy {
    pub terminate_on_heap_errors: RtlImageMitigationPolicy,
}

/// Buffer for [`ImageMitigationPolicy::ImageChildProcessPolicy`] (pairing inferred from
/// names — confirm against the native `RTL_IMAGE_MITIGATION_CHILD_PROCESS_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationChildProcessPolicy {
    pub disallow_child_process_creation: RtlImageMitigationPolicy,
}

/// Buffer for [`ImageMitigationPolicy::ImageUserShadowStackPolicy`] (pairing inferred from
/// names — confirm against the native `RTL_IMAGE_MITIGATION_USER_SHADOW_STACK_POLICY`).
///
/// Field order is ABI: the struct is `#[repr(C)]` and handed to ntdll by pointer.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct RtlImageMitigationUserShadowStackPolicy {
    pub user_shadow_stack: RtlImageMitigationPolicy,
    pub set_context_ip_validation: RtlImageMitigationPolicy,
    pub block_non_cet_binaries: RtlImageMitigationPolicy,
}

/// Zero-sized namespace type; its associated functions wrap ntdll's image-mitigation
/// policy exports, resolved at runtime via `LoadLibraryW` + `GetProcAddress`.
pub struct Ntdll;

impl Ntdll {
    /// Encodes `s` as UTF-16 with a trailing NUL, the form ntdll expects for a `PCWSTR`.
    ///
    /// An interior `'\0'` in `s` makes the callee see a truncated string.
    fn to_wide_nul(s: &str) -> Vec<u16> {
        s.encode_utf16().chain(std::iter::once(0)).collect()
    }

    /// Calls `ntdll!RtlQueryImageMitigationPolicy`, resolved at runtime.
    ///
    /// * `image_path` — image to query; `None` passes a null path (its meaning is defined
    ///   by ntdll — TODO confirm).
    /// * `policy` — which policy to read; `buffer` should point at the matching
    ///   `RtlImageMitigation*Policy` struct.
    /// * `flags` — forwarded unchanged.
    /// * `buffer` / `buffer_size` — output buffer and its size in bytes.
    ///
    /// Returns the export's `i32` status (presumably an `NTSTATUS`, 0 on success), or `-1`
    /// if `ntdll.dll` cannot be loaded or the export is missing.
    ///
    /// # Safety (caller contract)
    ///
    /// This function is not marked `unsafe`, but `buffer` must be valid for writes of
    /// `buffer_size` bytes for the duration of the call, since ntdll may write through it.
    pub fn rtl_query_image_mitigation_policy(
        image_path: Option<&str>,
        policy: ImageMitigationPolicy,
        flags: u32,
        buffer: *mut c_void,
        buffer_size: u32,
    ) -> i32 {
        type QueryFn = unsafe extern "system" fn(*const u16, u32, u32, *mut c_void, u32) -> i32;

        // The UTF-16 copy is bound in this scope so it outlives the FFI call below.
        // Previously it lived inside an `if let` arm and was freed before the call,
        // leaving ntdll with a dangling path pointer.
        let wide_path = image_path.map(Self::to_wide_nul);
        let path_ptr = wide_path.as_ref().map_or(ptr::null(), |w| w.as_ptr());

        // SAFETY: module and export names are valid NUL-terminated string literals.
        // The module handle is not freed; ntdll stays mapped for the process lifetime.
        let proc_addr = unsafe {
            match LoadLibraryW(&HSTRING::from("ntdll.dll")) {
                Ok(ntdll) => {
                    GetProcAddress(ntdll, windows::core::s!("RtlQueryImageMitigationPolicy"))
                }
                // Report load failure like a missing export instead of panicking.
                Err(_) => return -1,
            }
        };

        match proc_addr {
            // SAFETY: `QueryFn` is assumed to match the native prototype
            // (PCWSTR, IMAGE_MITIGATION_POLICY, ULONG, PVOID, ULONG) -> NTSTATUS.
            // `path_ptr` is null or points into `wide_path`, which is still alive.
            // Validity of `buffer`/`buffer_size` is the caller's contract (see docs).
            Some(proc_addr) => unsafe {
                let func: QueryFn = mem::transmute(proc_addr);
                func(path_ptr, policy as u32, flags, buffer, buffer_size)
            },
            None => -1,
        }
    }

    /// Calls `ntdll!RtlSetImageMitigationPolicy`, resolved at runtime.
    ///
    /// * `image_path` — image to configure; `None` passes a null path (its meaning is
    ///   defined by ntdll — TODO confirm).
    /// * `policy` — which policy to write; `buffer` should point at the matching
    ///   `RtlImageMitigation*Policy` struct.
    /// * `flags` — forwarded unchanged.
    /// * `buffer` / `buffer_size` — input buffer and its size in bytes.
    ///
    /// Returns the export's `i32` status (presumably an `NTSTATUS`, 0 on success), or `-1`
    /// if `ntdll.dll` cannot be loaded or the export is missing.
    ///
    /// # Safety (caller contract)
    ///
    /// This function is not marked `unsafe`, but `buffer` must be valid for reads of
    /// `buffer_size` bytes for the duration of the call.
    pub fn rtl_set_image_mitigation_policy(
        image_path: Option<&str>,
        policy: ImageMitigationPolicy,
        flags: u32,
        buffer: *const c_void,
        buffer_size: u32,
    ) -> i32 {
        type SetFn = unsafe extern "system" fn(*const u16, u32, u32, *const c_void, u32) -> i32;

        // Keep the UTF-16 copy alive until after the FFI call (see the query variant).
        let wide_path = image_path.map(Self::to_wide_nul);
        let path_ptr = wide_path.as_ref().map_or(ptr::null(), |w| w.as_ptr());

        // SAFETY: module and export names are valid NUL-terminated string literals.
        let proc_addr = unsafe {
            match LoadLibraryW(&HSTRING::from("ntdll.dll")) {
                Ok(ntdll) => {
                    GetProcAddress(ntdll, windows::core::s!("RtlSetImageMitigationPolicy"))
                }
                Err(_) => return -1,
            }
        };

        match proc_addr {
            // SAFETY: `SetFn` is assumed to match the native prototype; `path_ptr` is null
            // or points into the live `wide_path`; `buffer` validity is the caller's contract.
            Some(proc_addr) => unsafe {
                let func: SetFn = mem::transmute(proc_addr);
                func(path_ptr, policy as u32, flags, buffer, buffer_size)
            },
            None => -1,
        }
    }

    /// Typed wrapper over [`Self::rtl_query_image_mitigation_policy`].
    ///
    /// Starts from `T::default()` and lets ntdll fill it. Returns `Ok(value)` when the
    /// status is 0, otherwise `Err(status)`, including `-1` for resolution failures.
    ///
    /// `T` should be the `RtlImageMitigation*Policy` struct matching `policy_type`. ntdll
    /// writes raw bytes into it, so every bit pattern it may produce must be valid for `T`.
    /// The integer-only policy structs in this module qualify; `bool` or an enum would not.
    pub fn query_image_mitigation_policy<T>(
        image_path: Option<&str>,
        policy_type: ImageMitigationPolicy,
        flags: u32,
    ) -> Result<T, i32>
    where
        T: Default + Copy,
    {
        let mut policy: T = Default::default();
        let buffer_ptr = &mut policy as *mut T as *mut c_void;
        // Policy structs are a few KiB at most, far below u32::MAX, so this cannot truncate.
        let buffer_size = mem::size_of::<T>() as u32;

        let result = Self::rtl_query_image_mitigation_policy(
            image_path,
            policy_type,
            flags,
            buffer_ptr,
            buffer_size,
        );

        if result == 0 { Ok(policy) } else { Err(result) }
    }

    /// Typed wrapper over [`Self::rtl_set_image_mitigation_policy`].
    ///
    /// Passes `policy` by pointer with `size_of::<T>()` as the buffer size. Returns the raw
    /// status: presumably 0 on success, or `-1` if the export could not be resolved.
    /// `T` should be the `RtlImageMitigation*Policy` struct matching `policy_type`.
    pub fn set_image_mitigation_policy<T>(
        image_path: Option<&str>,
        policy_type: ImageMitigationPolicy,
        flags: u32,
        policy: &T,
    ) -> i32
    where
        T: Copy,
    {
        let buffer_ptr = policy as *const T as *const c_void;
        // Policy structs are a few KiB at most, far below u32::MAX, so this cannot truncate.
        let buffer_size = mem::size_of::<T>() as u32;

        Self::rtl_set_image_mitigation_policy(
            image_path,
            policy_type,
            flags,
            buffer_ptr,
            buffer_size,
        )
    }
}

impl Default for RtlImageMitigationPolicy {
    fn default() -> Self {
        Self { policy_state: 0 }
    }
}

impl Default for RtlImageMitigationAslrPolicy {
    fn default() -> Self {
        Self {
            force_relocate_images: Default::default(),
            bottom_up_randomization: Default::default(),
            high_entropy_randomization: Default::default(),
        }
    }
}

impl Default for RtlImageMitigationDepPolicy {
    fn default() -> Self {
        Self {
            dep: Default::default(),
        }
    }
}

impl Default for RtlImageMitigationStrictHandleCheckPolicy {
    fn default() -> Self {
        Self {
            strict_handle_checks: Default::default(),
        }
    }
}

impl Default for RtlImageMitigationSystemCallDisablePolicy {
    fn default() -> Self {
        Self {
            block_win32k_system_calls: Default::default(),
            block_fsctl_system_calls: Default::default(),
        }
    }
}

impl Default for RtlImageMitigationExtensionPointDisablePolicy {
    fn default() -> Self {
        Self {
            disable_extension_points: Default::default(),
        }
    }
}

impl Default for RtlImageMitigationDynamicCodePolicy {
    fn default() -> Self {
        Self {
            block_dynamic_code: Default::default(),
        }
    }
}

impl Default for RtlImageMitigationControlFlowGuardPolicy {
    fn default() -> Self {
        Self {
            control_flow_guard: Default::default(),
            strict_control_flow_guard: Default::default(),
        }
    }
}

impl Default for RtlImageMitigationBinarySignaturePolicy {
    fn default() -> Self {
        Self {
            block_non_microsoft_signed_binaries: Default::default(),
            enforce_signing_on_module_dependencies: Default::default(),
        }
    }
}

impl Default for RtlImageMitigationFontDisablePolicy {
    fn default() -> Self {
        Self {
            disable_non_system_fonts: Default::default(),
        }
    }
}

impl Default for RtlImageMitigationImageLoadPolicy {
    fn default() -> Self {
        Self {
            block_remote_image_loads: Default::default(),
            block_low_label_image_loads: Default::default(),
            prefer_system32: Default::default(),
        }
    }
}

impl Default for RtlImageMitigationPayloadRestrictionPolicy {
    fn default() -> Self {
        Self {
            enable_export_address_filter: Default::default(),
            enable_export_address_filter_plus: Default::default(),
            enable_import_address_filter: Default::default(),
            enable_rop_stack_pivot: Default::default(),
            enable_rop_caller_check: Default::default(),
            enable_rop_sim_exec: Default::default(),
            module_list: [0; 512],
        }
    }
}

impl Default for RtlImageMitigationSehopPolicy {
    fn default() -> Self {
        Self {
            sehop: Default::default(),
        }
    }
}

impl Default for RtlImageMitigationHeapPolicy {
    fn default() -> Self {
        Self {
            terminate_on_heap_errors: Default::default(),
        }
    }
}

impl Default for RtlImageMitigationChildProcessPolicy {
    fn default() -> Self {
        Self {
            disallow_child_process_creation: Default::default(),
        }
    }
}

impl Default for RtlImageMitigationUserShadowStackPolicy {
    fn default() -> Self {
        Self {
            user_shadow_stack: Default::default(),
            set_context_ip_validation: Default::default(),
            block_non_cet_binaries: Default::default(),
        }
    }
}
